from data_configs import DATASETS
import argparse
import numpy as np
import json
from tqdm import tqdm
import os
import re
import pickle
import torch
from transformers import Qwen2_5_VLForConditionalGeneration, AutoProcessor
from qwen_vl_utils import process_vision_info
import random


# refer to LEO: embodied-generalist
# https://github.com/embodied-generalist/embodied-generalist/blob/477dc44b8b18dbfbe6823c307436d896ec8b062e/evaluator/scanqa_eval.py#L41-L50
def answer_match(pred, gts):
    """Score one prediction against a list of reference answers.

    Returns a ``(em, refined_em)`` pair of 0/1 flags:

    * ``(0, 0)`` for an empty prediction or when nothing matches,
    * ``(1, 1)`` when ``pred`` is literally one of ``gts``,
    * ``(0, 1)`` when, after removing all whitespace, the prediction contains
      a reference or a reference contains the prediction.
    """
    if len(pred) == 0:
        return 0, 0
    if pred in gts:
        return 1, 1

    # Whitespace-insensitive containment (either direction) only earns the
    # refined score, never the strict one.
    squeezed_pred = ''.join(pred.split())
    for reference in gts:
        squeezed_ref = ''.join(reference.split())
        if squeezed_pred in squeezed_ref or squeezed_ref in squeezed_pred:
            return 0, 1
    return 0, 0


# refer to LEO: embodied-generalist
# https://github.com/embodied-generalist/embodied-generalist/blob/477dc44b8b18dbfbe6823c307436d896ec8b062e/data/data_utils.py#L322-L379
# Ordered rewrite rules used by clean_answer. Each entry is
# (compiled pattern, replacement) and the rules are applied top to bottom, so
# the order is part of the behaviour. All patterns are raw strings so that
# regex escapes such as ``\.`` and ``\s`` are not treated as (invalid) Python
# string escapes.
_CLEAN_ANSWER_RULES = tuple(
    (re.compile(pattern), replacement)
    for pattern, replacement in (
        # --- whitespace normalisation ---
        (r'[ ]+$', ''),
        (r'^[ ]+', ''),
        (r' {2,}', ' '),
        # NOTE: cannot match any more (runs of spaces were just collapsed);
        # kept so the rule sequence stays identical to the original.
        (r'\.[ ]{2,}', '. '),

        # --- drop every character outside letters, digits, , ' - : and whitespace ---
        (r"[^a-zA-Z0-9,'\s\-:]+", ''),
        # NOTE: the two rules below cannot match either, because the rule
        # above has already removed these characters; kept for the same reason.
        ('ç', 'c'),
        ('’', "'"),

        # --- common typos / spelling variants ---
        (r'\bletf\b', 'left'),
        (r'\blet\b', 'left'),
        (r'\btehre\b', 'there'),
        (r'\brigth\b', 'right'),
        (r'\brght\b', 'right'),
        (r'\bbehine\b', 'behind'),
        (r'\btv\b', 'TV'),
        (r'\bchai\b', 'chair'),
        (r'\bwasing\b', 'washing'),
        (r'\bwaslked\b', 'walked'),
        (r'\boclock\b', "o'clock"),
        (r'\bo\'[ ]+clock\b', "o'clock"),

        # --- digit to word, only for answer ---
        (r'\b0\b', 'zero'),
        (r'\bnone\b', 'zero'),
        (r'\b1\b', 'one'),
        (r'\b2\b', 'two'),
        (r'\b3\b', 'three'),
        (r'\b4\b', 'four'),
        (r'\b5\b', 'five'),
        (r'\b6\b', 'six'),
        (r'\b7\b', 'seven'),
        (r'\b8\b', 'eight'),
        (r'\b9\b', 'nine'),
        (r'\b10\b', 'ten'),
        (r'\b11\b', 'eleven'),
        (r'\b12\b', 'twelve'),
        (r'\b13\b', 'thirteen'),
        (r'\b14\b', 'fourteen'),
        (r'\b15\b', 'fifteen'),
        (r'\b16\b', 'sixteen'),
        (r'\b17\b', 'seventeen'),
        (r'\b18\b', 'eighteen'),
        (r'\b19\b', 'nineteen'),
        (r'\b20\b', 'twenty'),
        (r'\b23\b', 'twenty-three'),

        # --- misc ---
        # strip a single trailing digit glued to a word: no1, mat2, etc
        (r'\b([a-zA-Z]+)([0-9])\b', r'\g<1>'),
        # drop a leading article in front of a word
        (r'\ba\b ([a-zA-Z]+)', r'\g<1>'),
        (r'\ban\b ([a-zA-Z]+)', r'\g<1>'),
        (r'\bthe\b ([a-zA-Z]+)', r'\g<1>'),

        (r'\bbackwards\b', 'backward'),
    )
)


def clean_answer(data):
    """Normalise an answer string before exact-match comparison.

    The text is lower-cased and then passed through ``_CLEAN_ANSWER_RULES`` in
    order: surrounding/duplicate spaces are removed, characters outside
    ``[a-zA-Z0-9,'\\s\\-:]`` are dropped, a fixed list of typos is corrected,
    standalone numbers (0-20 and 23) and the word "none" are spelled out, a
    trailing digit glued to a word is stripped, and the articles
    "a"/"an"/"the" are removed when followed by a word.

    Args:
        data: the raw answer string (prediction or reference).

    Returns:
        The normalised string. Note that "tv" is deliberately rewritten to
        upper-case "TV", so the result is not guaranteed to be all lower-case.
    """
    data = data.lower()
    for pattern, replacement in _CLEAN_ANSWER_RULES:
        data = pattern.sub(replacement, data)
    return data


# Module-level memo used by cached_process_vision_info: maps a key derived
# from the video entry of a message list to the value returned by
# process_vision_info. Nothing in this file evicts entries, so it grows for
# the lifetime of the process.
VIDEO_INFO_CACHE: dict = {}


def get_args():
    """Parse the command-line options of this evaluation script.

    Returns:
        The ``argparse.Namespace`` with ``dataset``, ``split``, ``model_base``,
        ``batch_size``, ``checkpoint_dir``, ``resume`` and ``device``.
    """
    cli = argparse.ArgumentParser(
        description='Evaluation for training-free video temporal grounding (Single GPU Version)',
    )

    # What to evaluate.
    cli.add_argument('--dataset', type=str, default='charades', help='Specify the dataset.')
    cli.add_argument('--split', type=str, default='default', help='Specify the split.')

    # Model and runtime.
    cli.add_argument("--model_base", type=str, default="/path/to/qwen-model")
    cli.add_argument("--batch_size", type=int, default=1, help="Batch size")
    cli.add_argument("--device", type=str, default="cuda:0", help="GPU device to use")

    # Output / checkpointing.
    cli.add_argument("--checkpoint_dir", type=str, default="checkpoints",
                     help="Directory to save checkpoints")
    cli.add_argument("--resume", action="store_true", help="Resume from checkpoint")

    return cli.parse_args()


def calc_iou(candidates, gt):
    """Compute the temporal IoU of each candidate span against one reference span.

    The denominator is the extent of the smallest span covering both
    intervals (``max(end) - min(start)``), matching the original formula.

    Args:
        candidates: array-like of ``(start, end)`` rows; a single
            ``(start, end)`` pair is also accepted and treated as one row.
        gt: indexable ``(start, end)`` reference span.

    Returns:
        1-D float ``numpy.ndarray`` with one IoU per candidate. Candidates
        whose covering span has no positive length (e.g. two identical
        zero-length spans) get ``0.0`` instead of ``nan``.
    """
    # Accept plain lists/tuples as well as arrays; float dtype keeps the
    # masked division below well-defined for integer input.
    candidates = np.atleast_2d(np.asarray(candidates, dtype=float))
    start, end = candidates[:, 0], candidates[:, 1]
    s, e = gt[0], gt[1]

    inter = (np.minimum(end, e) - np.maximum(start, s)).clip(min=0)
    union = np.maximum(end, e) - np.minimum(start, s)

    # Divide only where the denominator is positive; the rest stays at 0.
    return np.divide(inter, union, out=np.zeros_like(inter), where=union > 0)


def cached_process_vision_info(messages, return_video_kwargs=False):
    """Memoised wrapper around ``process_vision_info`` for video messages.

    Results are stored in the module-level ``VIDEO_INFO_CACHE`` so that
    repeated calls for the same video (with different text prompts) reuse the
    first result instead of processing the video again.

    The cache key is built from every content entry that carries a ``video``
    or ``image`` key (including per-entry settings such as pixel limits) plus
    the ``return_video_kwargs`` flag, because the flag changes the shape of
    the returned tuple. Text entries are not part of the key. Messages that
    contain no video entry are never cached.

    NOTE: the cache is unbounded and holds whatever ``process_vision_info``
    returned for every distinct video seen so far.

    Args:
        messages: list of chat-message dicts whose ``content`` is a list of
            content-entry dicts.
        return_video_kwargs: forwarded to ``process_vision_info``.

    Returns:
        Whatever ``process_vision_info`` returns for these arguments.
    """
    vision_parts = [
        part
        for msg in messages
        for part in msg.get('content', [])
        if isinstance(part, dict) and ('video' in part or 'image' in part)
    ]

    # Nothing worth caching without a video; also avoids sharing one cache
    # slot between unrelated video-less requests.
    if not any('video' in part for part in vision_parts):
        return process_vision_info(messages, return_video_kwargs=return_video_kwargs)

    cache_key = (tuple(repr(part) for part in vision_parts), bool(return_video_kwargs))
    if cache_key not in VIDEO_INFO_CACHE:
        VIDEO_INFO_CACHE[cache_key] = process_vision_info(
            messages, return_video_kwargs=return_video_kwargs)
    return VIDEO_INFO_CACHE[cache_key]


def inference(video_path, prompt, model, processor, max_new_tokens=2048, device="cuda:0"):
    """Generate the model's text response for one prompt about one video.

    Args:
        video_path: value placed under the ``video`` key of the chat message.
        prompt: text of the user turn.
        model: object exposing ``generate(**inputs, max_new_tokens=...)``.
        processor: object exposing ``apply_chat_template``, ``__call__`` and
            ``batch_decode``.
        max_new_tokens: generation budget passed to ``model.generate``.
        device: device the processed inputs are moved to.

    Returns:
        The decoded text of the first generated sequence, with the prompt
        tokens removed.
    """
    text_part = {"type": "text", "text": prompt}
    video_part = {
        "video": video_path,
        "total_pixels": 3584 * 28 * 28,
        "min_pixels": 16 * 28 * 28,
    }
    messages = [{"role": "user", "content": [text_part, video_part]}]

    chat_text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    image_inputs, video_inputs, video_kwargs = cached_process_vision_info(messages, return_video_kwargs=True)

    model_inputs = processor(
        text=[chat_text],
        images=image_inputs,
        videos=video_inputs,
        fps=video_kwargs['fps'],
        padding=True,
        return_tensors="pt",
    ).to(device)

    with torch.no_grad():
        output_ids = model.generate(**model_inputs, max_new_tokens=max_new_tokens)

    # generate() returns prompt + completion; keep only the completion part
    # of each sequence before decoding.
    completions = [
        sequence[len(model_inputs.input_ids[row]):]
        for row, sequence in enumerate(output_ids)
    ]
    decoded = processor.batch_decode(completions, skip_special_tokens=True, clean_up_tokenization_spaces=True)
    return decoded[0]


def parse_timestamp_output(output_string):
    """Extract a ``(start, end)`` time range from a model response.

    A range is any ``<number> to <number>`` or ``<number> and <number>``
    occurrence. The content of the first ``<answer>...</answer>`` block is
    searched first, so ranges mentioned elsewhere in the response (for example
    in the reasoning) do not override the stated answer; if that block is
    missing or contains no range, the whole response is searched instead.
    When several ranges are found, the last one wins.

    Args:
        output_string: raw text produced by the model.

    Returns:
        ``(start, end)`` as floats, or ``(None, None)`` when no range is found.
    """
    range_pattern = r"(\d+\.?\d*) (to|and) (\d+\.?\d*)"

    found = []
    # DOTALL so an answer block spanning several lines is still recognised.
    answer_block = re.search(r"<answer>(.*?)</answer>", output_string, re.DOTALL)
    if answer_block:
        found = re.findall(range_pattern, answer_block.group(1))
    if not found:
        found = re.findall(range_pattern, output_string)
    if not found:
        return None, None

    start_time_str, _, end_time_str = found[-1]
    try:
        return float(start_time_str), float(end_time_str)
    except ValueError:
        return None, None


def get_sqa_question_type(question):
    """Map a question to an integer category based on its leading characters.

    Leading whitespace is ignored and the comparison is case-insensitive.
    Categories: 0 = "what", 1 = "is", 2 = "how", 3 = "can", 4 = "which",
    5 = anything else. Only the raw prefix is compared, so e.g. "Island ..."
    falls into category 1.
    """
    stripped = question.lstrip()
    # Checked in this order; the first matching prefix decides the category.
    prefix_to_type = (
        ('what', 0),
        ('is', 1),
        ('how', 2),
        ('can', 3),
        ('which', 4),
    )
    for prefix, type_id in prefix_to_type:
        if stripped[:len(prefix)].lower() == prefix:
            return type_id
    return 5     # others


# GROUND_TEMPLATE = """To accurately pinpoint the event "[EVENT]" in the video, determine the precise time period of the event.

# Output your thought process within the <think> </think> tags, including analysis with either specific timestamps (xx.xx) or time ranges (xx.xx to xx.xx) in <timestep> </timestep> tags.

# Then, provide the start and end times (in seconds, precise to two decimal places) in the format "start time to end time" within the <answer> </answer> tags. For example: "12.54 to 17.83"."""

# QUESTION_TEMPLATE = """To accurately pinpoint the object described as "[EVENT]" in the video, determine the precise time period of the occurance of the object.

# Output your thought process within the <think> </think> tags, including analysis with either specific timestamps (xx.x) or time ranges (x.xx to xx.x) in <timestep> </timestep> tags.

# Then, provide the start and end times (in seconds, precise to one decimal places) in the format "start time to end time" within the <answer> </answer> tags. For example: "12.5 to 17.0"."""

# GROUND_TEMPLATE = """To accurately pinpoint the event "[EVENT]" in the video, determine the precise time period of the event.

# Provide the start and end times (in seconds, precise to two decimal places) in the format "start time to end time" within the <answer> </answer> tags. For example: "12.54 to 17.83"."""

# Prompt sent to the model for every item. process_work_items substitutes the
# literal "[EVENT]" placeholder with "<situation> <question>"; the surrounding
# double quotes are part of the prompt. The model is asked to wrap its final
# answer in <answer> </answer> tags, which is what the evaluation loop parses.
QUESTION_TEMPLATE = """"[EVENT]" Answer the question using a single word or phrase. Output the thinking process in <think> </think> and final answer (number) in <answer> </answer> tags."""


def create_work_items(data):
    """Flatten per-video annotations into one shuffled list of work items.

    Args:
        data: mapping of video id to an annotation dict that has a
            ``'sentences'`` sequence.

    Returns:
        A list with one ``{'vid', 'ann', 'sentence_idx'}`` dict per sentence,
        shuffled in place with the global ``random`` state (so the order is
        not reproducible unless the caller seeds ``random``).
    """
    work_items = [
        {'vid': vid, 'ann': ann, 'sentence_idx': sentence_idx}
        for vid, ann in data.items()
        for sentence_idx in range(len(ann['sentences']))
    ]
    random.shuffle(work_items)
    return work_items


def setup_model(model_base, device):
    """Load the Qwen2.5-VL model and its processor from ``model_base``.

    Args:
        model_base: path or identifier handed to both ``from_pretrained`` calls.
        device: passed as ``device_map`` when loading the model.

    Returns:
        ``(model, processor)``.
    """
    print(f"Setting up model on device {device}")

    load_options = dict(
        torch_dtype=torch.bfloat16,
        use_sliding_window=True,
        attn_implementation="flash_attention_2",
        device_map=device,
    )
    model = Qwen2_5_VLForConditionalGeneration.from_pretrained(model_base, **load_options)
    processor = AutoProcessor.from_pretrained(model_base)
    return model, processor


def get_checkpoint_path(checkpoint_dir):
    """Return the checkpoint file path inside ``checkpoint_dir``.

    The directory (including missing parents) is created as a side effect.
    """
    checkpoint_file = os.path.join(checkpoint_dir, "checkpoint.pkl")
    os.makedirs(checkpoint_dir, exist_ok=True)
    return checkpoint_file


def load_checkpoint(checkpoint_path):
    """Load a pickled checkpoint, falling back to a fresh state.

    Args:
        checkpoint_path: path of the pickle file to read.

    Returns:
        The unpickled object when the file exists and loads cleanly;
        otherwise a new state dict with an empty ``processed_items`` set, an
        empty ``ious`` list and a zeroed three-element ``recall`` array. A
        load failure is printed, not raised.
    """
    fresh_state = {'processed_items': set(), 'ious': [], 'recall': np.array([0, 0, 0])}

    if not os.path.exists(checkpoint_path):
        return fresh_state

    try:
        with open(checkpoint_path, 'rb') as handle:
            return pickle.load(handle)
    except Exception as err:
        # Best effort: an unreadable checkpoint means starting over.
        print(f"Error loading checkpoint: {err}")
    return fresh_state


def save_checkpoint(checkpoint_path, state):
    """Pickle ``state`` to ``checkpoint_path`` without clobbering on failure.

    The data is first written to ``<checkpoint_path>.tmp`` and then moved over
    the target with ``os.replace``. If pickling or writing fails (or the
    process is interrupted mid-write), the previous checkpoint file is left
    untouched instead of being truncated.

    Args:
        checkpoint_path: destination file path.
        state: any picklable object.

    Raises:
        Whatever ``open``, ``pickle.dump`` or ``os.replace`` raise; the
        temporary file is removed before the error propagates.
    """
    tmp_path = f"{checkpoint_path}.tmp"
    try:
        with open(tmp_path, 'wb') as f:
            pickle.dump(state, f)
        os.replace(tmp_path, checkpoint_path)
    finally:
        # Only still present when something above failed.
        if os.path.exists(tmp_path):
            os.remove(tmp_path)


def _find_video_path(video_dir_path, vid):
    """Return the first existing ``<vid>.mp4|mkv|webm`` under ``video_dir_path``, or None."""
    for ext in ('mp4', 'mkv', 'webm'):
        path = os.path.join(video_dir_path, f"{vid}.{ext}")
        if os.path.isfile(path):
            return path
    return None


def _extract_prediction(content):
    """Pull the raw answer text out of a model response.

    Uses the text inside the first ``<answer>...</answer>`` pair when present,
    otherwise the whole response. For answers longer than one character, one
    trailing period is dropped and the first character is lower-cased.
    """
    content_match = re.search(r'<answer>(.*?)</answer>', content, re.DOTALL)
    pred = content_match.group(1).strip() if content_match else content.strip()
    if len(pred) > 1:
        if pred[-1] == '.':
            pred = pred[:-1]
        pred = pred[0].lower() + pred[1:]
    return pred


def process_work_items(work_items, video_dir_path, model_base, device, checkpoint_dir, resume=False):
    """Run the model on every QA item and score exact-match accuracy.

    For each item the prompt is built from ``QUESTION_TEMPLATE``, the matching
    video is looked up in ``video_dir_path``, the model response is parsed and
    normalised with ``clean_answer``, and it is compared to the normalised
    reference answers with ``answer_match``. All predictions so far are
    written to ``<checkpoint_dir>/preds.json`` after every item.

    Args:
        work_items: iterable of dicts with the keys ``'scene_id'``,
            ``'situation'``, ``'question'`` and ``'answers'``.
        video_dir_path: directory containing ``<scene_id>.<ext>`` videos.
        model_base: model path/identifier passed to ``setup_model``.
        device: device used for loading the model and running inference.
        checkpoint_dir: output directory for ``preds.json`` (created if needed).
        resume: accepted for interface compatibility; currently not used.

    Returns:
        ``(em, em_refined)``: two lists of 0/1 flags, one entry per item that
        was successfully evaluated. Items without a video file, or whose
        processing raised, contribute no entry.
    """
    em = []
    em_refined = []
    preds = []
    total = 0

    # Create the output directory before the expensive model load, so a bad
    # path fails fast and the per-item preds.json write cannot fail just
    # because the directory is missing.
    os.makedirs(checkpoint_dir, exist_ok=True)
    preds_path = os.path.join(checkpoint_dir, 'preds.json')

    model, processor = setup_model(model_base, device)

    pbar = tqdm(work_items)
    for item in pbar:
        total += 1
        vid = item['scene_id']
        question = f"{item['situation']} {item['question']}"
        prompt = QUESTION_TEMPLATE.replace('[EVENT]', question)

        video_path = _find_video_path(video_dir_path, vid)
        if video_path is None:
            # Report the gap instead of silently shrinking the evaluated set.
            print(f"No video file found for {vid} in {video_dir_path}; skipping item")
            continue

        try:
            content = inference(video_path, prompt, model, processor, device=device)

            pred = clean_answer(_extract_prediction(content))
            ref_captions = [clean_answer(s) for s in item['answers']]

            em_flag, em_refined_flag = answer_match(pred, ref_captions)
            em.append(em_flag)
            em_refined.append(em_refined_flag)

            pbar.set_postfix({
                "EM": sum(em) / len(em),
                "EM_refined": sum(em_refined) / len(em_refined),
            })
            print(vid, content)

            preds.append({
                'video_id': vid,
                'question': question,
                'response': content,
                'pred': pred,
                'gt': ref_captions,
                'em': em_flag,
                'em_refined': em_refined_flag,
            })

            # Rewritten after every item so partial results survive an
            # interrupted run.
            with open(preds_path, 'w') as f:
                json.dump(preds, f, indent=4)

        except Exception as e:
            # Best effort: keep evaluating the remaining items, but say why
            # this one failed.
            print(f"Error processing {item}: {e!r}")

    print('=== final result ===')
    print(f'evaluated {len(em)} of {total} items')
    if em:
        print('EM:', sum(em) / len(em))
        print('EM_refined:', sum(em_refined) / len(em_refined))
    else:
        # Avoid ZeroDivisionError when every item was skipped or failed.
        print('EM: n/a (no item was evaluated)')

    return em, em_refined


def evaluate(data, args):
    """Evaluate ``data`` with the settings in ``args``.

    Looks up the video directory of ``args.dataset`` in ``DATASETS`` and
    forwards everything to ``process_work_items``.

    Returns:
        The ``(em, em_refined)`` pair produced by ``process_work_items``.
    """
    video_dir_path = DATASETS[args.dataset]['video_path']

    em, em_refined = process_work_items(
        data,
        video_dir_path,
        args.model_base,
        args.device,
        args.checkpoint_dir,
        args.resume,
    )
    return em, em_refined


if __name__ == '__main__':
    args = get_args()

    # Validate the CLI choices with explicit errors: ``assert`` statements are
    # stripped when Python runs with -O and give no hint about valid values.
    if args.dataset not in DATASETS:
        raise SystemExit(
            f"Unknown dataset '{args.dataset}'. Available: {', '.join(map(str, DATASETS))}")
    dataset = DATASETS[args.dataset]
    if args.split not in dataset['splits']:
        raise SystemExit(
            f"Unknown split '{args.split}' for dataset '{args.dataset}'. "
            f"Available: {', '.join(map(str, dataset['splits']))}")

    print('evaluate', args.dataset, args.split)

    # load data (explicit encoding so the result does not depend on the locale)
    with open(dataset['splits'][args.split]['annotation_file'], encoding='utf-8') as f:
        data = json.load(f)

    evaluate(data, args)
